import argparse
import json
import os
import sys
# Add the parent directory of `src` to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from src.experiment import Experiment

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

if __name__ == '__main__':
    # Set up argument parser
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, required=True)
    parser.add_argument('--exp_name', type=str)
    parser.add_argument('--n_iterations', type=int, default=None)
    parser.add_argument('--bs', type=int, default=None)
    parser.add_argument('--traj_length', type=int, default=None)
    parser.add_argument('--loss', type=str, default=None)
    parser.add_argument('--tie_weights', type=str2bool, default=None)
    parser.add_argument('--metad', type=str2bool, default=None)
    parser.add_argument('--epsilon_noisy', type=str2bool, default=None)
    parser.add_argument('--noise_exploration', type=str2bool, default=None)
    parser.add_argument('--replay_buffer', type=str2bool, default=None)
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--lamda', type=float, default=None)
    parser.add_argument('--train_metad', type=str2bool, default=None)
    parser.add_argument('--freq_rb', type=int, default=None)
    parser.add_argument('--freq_md', type=int, default=None)
    parser.add_argument('--log_reward_threshold', type=float, default=None)
    parser.add_argument('--epsilon', type=float, default=None)
    parser.add_argument('--lr', type=float, default=None)
    parser.add_argument('--optimizer', type=str, default=None)
    parser.add_argument('--noise_profile', type=str, default=None)
    parser.add_argument('--lr_schedule', type=str, default=None)
    parser.add_argument('--n_threads', type=int, default=None)
    parser.add_argument('--n_processes', type=int, default=None)
    parser.add_argument('--repeats', type=int, default=None)
    parser.add_argument('--hidden_dim', type=int, default=None)
    parser.add_argument('--hidden_layers', type=int, default=None)
    parser.add_argument('--mixture_dimension', type=int, default=None)
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--replay_buffer_type', type=str, default=None)
    parser.add_argument('--reward_only_replay_buffer', type=str2bool, default=None)
    parser.add_argument('--metad_epsilon', type=float, default=None)
    parser.add_argument('--metad_concentration', type=float, default=None)
    parser.add_argument('--metad_gamma', type=float, default=None)
    parser.add_argument('--metad_w', type=float, default=None)
    parser.add_argument('--min_log_concentration', type=float, default=None)
    parser.add_argument('--thompson_sampling', type=str2bool, default=None)
    parser.add_argument('--local_search', type=str2bool, default=None)
    parser.add_argument('--local_search_K', type=int, default=None)
    parser.add_argument('--nested_sampling', type=str2bool, default=None)
    parser.add_argument('--max_impulse_std', type=float, default=None)
    parser.add_argument('--max_policy_std', type=float, default=None)
    parser.add_argument('--edge_size', type=float, default=None)
    args = parser.parse_args()
    
    # Load the config file json file
    if args.env == 'pendulum':
        with open('./src/config/pendulum_environment.json', 'r') as f:
            config = json.load(f)
    elif args.env == 'line':
        with open('./src/config/line_environment.json', 'r') as f:
            config = json.load(f)
    elif args.env == 'alanine':
        with open('./src/config/alanine_environment.json', 'r') as f:
            config = json.load(f)
    elif args.env == 'grid':
        with open('./src/config/grid_environment.json', 'r') as f:
            config = json.load(f)
    elif args.env == 'hypergrid3':
        with open('./src/config/hypergrid3_environment.json', 'r') as f:
            config = json.load(f)
    elif args.env == 'hypergrid4':
        with open('./src/config/hypergrid4_environment.json', 'r') as f:
            config = json.load(f)
    else:
        raise ValueError("Invalid environment. Choose between pendulum, grid, line and alanine.")

    def update_config_if_provided(config, key, value):
        if value is not None:
            keys = key.split('.')
            d = config
            for k in keys[:-1]:
                d = d.setdefault(k, {})
            d[keys[-1]] = value

    update_config_if_provided(config, 'exp_name', args.exp_name)
    update_config_if_provided(config, 'n_iterations', args.n_iterations)
    update_config_if_provided(config, 'batch_size', args.bs)
    update_config_if_provided(config, 'gfn.trajectory_length', args.traj_length)
    update_config_if_provided(config, 'device', args.device)
    update_config_if_provided(config, 'n_threads', args.n_threads)
    update_config_if_provided(config, 'n_processes', args.n_processes)
    update_config_if_provided(config, 'repeats', args.repeats)
    update_config_if_provided(config, 'gfn.loss', args.loss)
    update_config_if_provided(config, 'freq_rb', args.freq_rb)
    update_config_if_provided(config, 'freq_md', args.freq_md)
    update_config_if_provided(config, 'replay_buffer.reward_threshold', args.log_reward_threshold)
    update_config_if_provided(config, 'gfn.epsilon_noisy.epsilon', args.epsilon)
    update_config_if_provided(config, 'gfn.lr_model', args.lr)
    update_config_if_provided(config, 'gfn.optimizer', args.optimizer)
    update_config_if_provided(config, 'gfn.noise_exploration.noise_profile', args.noise_profile)
    update_config_if_provided(config, 'gfn.lr_schedule', args.lr_schedule)
    update_config_if_provided(config, 'gfn.thompson_sampling', args.thompson_sampling)
    update_config_if_provided(config, 'gfn.hidden_dim', args.hidden_dim)
    update_config_if_provided(config, 'gfn.n_hidden_layers', args.hidden_layers)
    update_config_if_provided(config, 'env.mixture_dim', args.mixture_dimension)
    update_config_if_provided(config, 'replay_buffer.sampling_method', args.replay_buffer_type)
    update_config_if_provided(config, 'env.min_log_concentration', args.min_log_concentration)
    update_config_if_provided(config, 'seed', args.seed)
    update_config_if_provided(config, 'replay_buffer.reward_only', args.reward_only_replay_buffer)
    update_config_if_provided(config, 'metad.train', args.train_metad)
    update_config_if_provided(config, 'metad.active', args.metad)
    update_config_if_provided(config, 'gfn.noise_exploration.active', args.noise_exploration)
    update_config_if_provided(config, 'replay_buffer.active', args.replay_buffer)
    update_config_if_provided(config, 'gfn.tie_weights', args.tie_weights)
    update_config_if_provided(config, 'metad.epsilon', args.metad_epsilon)
    update_config_if_provided(config, 'metad.concentration', args.metad_concentration)
    update_config_if_provided(config, 'metad.gamma', args.metad_gamma)
    update_config_if_provided(config, 'metad.w', args.metad_w)
    update_config_if_provided(config, 'env.max_impulse_std', args.max_impulse_std)
    update_config_if_provided(config, 'env.max_policy_std', args.max_policy_std)

    update_config_if_provided(config, 'gfn.local_search', args.local_search)
    update_config_if_provided(config, 'gfn.local_search_K', args.local_search_K)

    update_config_if_provided(config, 'gfn.nested_sampling', args.nested_sampling)

    if args.env == 'grid' or args.env == 'hypergrid':
        update_config_if_provided(config, 'env.edge_size', args.edge_size)

    assert config["gfn"]["loss"] in ['TB', 'STB', 'DB'], "Invalid loss function. Choose between TB and STB"

    exp = Experiment(config)
    exp.train()




